[PyTorch] [torch.compile] transformer_engine.pytorch.autocast support inside torch.compile #2759
pggPL wants to merge 14 commits into NVIDIA:main
Conversation
Move FP8 global state onto an instance so Dynamo can trace autocast state updates, explicitly reject DelayedScaling under torch.compile, and add toy compile tests that keep TE forward/backward opaque while covering supported recipes. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Drop the standalone global dict and dataclass mutation experiments now that the torch.compile regression coverage lives in the focused autocast test file. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Use compiler constant-result wrappers for support checks and rename the module-level FP8 singleton to `_FP8_GLOBAL_STATE` for clearer semantics. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
force-pushed 338ddae to b5d46fd
Restore the FP8 naming and remove extra state access helpers so the torch.compile changes stay focused on the instance-backed global state. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Drop stale availability fields from FP8GlobalState now that support checks use module-level cached results instead of manager state. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Resolve conflicts in the FP8 torch.compile changes while preserving the upstream updates in graph.py and module/base.py. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary
This PR enables `transformer_engine.pytorch.autocast` support inside `torch.compile`. Key issues identified (not yet discussed in prior threads):
Previously-flagged issues still open: Confidence Score: 2/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["te.autocast(recipe=R, enabled=E)"] --> B{enabled?}
    B -- yes --> C["check_recipe_support(recipe)"]
    C --> D{is_compiling AND\nDelayedScaling?}
    D -- yes --> E["RuntimeError ✗"]
    D -- no --> F["save fp8_state tuple\nfrom quantization_state.*"]
    B -- no --> F
    F --> G["autocast_enter(enabled, recipe, ...)"]
    G --> H{recipe is None?}
    H -- yes --> I["get_default_fp8_recipe()\nassert not is_compiling ⚠️"]
    H -- no --> J["quantization_state.autocast_arguments[key] = recipe,group\nquantization_state.fp8_* = new values"]
    I --> J
    J --> K["yield — user code runs"]
    K --> L["finally: restore quantization_state.*\nfrom saved fp8_state tuple"]
    L --> M["autocast_exit: depth--, reduce amaxes if needed"]
    subgraph "support cache (module-level globals)"
        N["_FP8_SUPPORT / _MXFP8_SUPPORT\n_NVFP4_SUPPORT / _FP8_BLOCK_SCALING_SUPPORT"]
        O["check_fp8_support() etc.\n@assume_constant_result"] --> N
    end
    subgraph "FP8GlobalState singleton"
        P["quantization_state\n.fp8_enabled / .fp8_recipe\n.autocast_depth / .is_first_fp8_module\n.global_amax_buffer / .autocast_arguments\n..."]
    end
    J -.-> P
    L -.-> P
```
Reviews (5): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
Replace custom-op-based ToyLinear with a minimal version using F.linear. Add test_autocast_sanity (parametrized over all recipes including NVFP4) and test_autocast_nested_sanity with CustomRecipes. Both verify fullgraph=True compilation without graph breaks. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Made-with: Cursor
Verify that te.autocast(recipe=DelayedScaling(), enabled=True) raises a clear RuntimeError when used inside torch.compile. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Made-with: Cursor
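A minimal stand-in for that guard (using a fake `is_compiling()` flag instead of `torch.compiler.is_compiling()`, and a stub recipe class rather than TE's actual `DelayedScaling`) behaves like:

```python
class DelayedScaling:  # stub for transformer_engine.common.recipe.DelayedScaling
    pass

def is_compiling():
    # Stand-in for torch.compiler.is_compiling(); pretend we are tracing.
    return True

def check_recipe_support(recipe):
    # Delayed scaling relies on amax-history state that cannot be traced,
    # so it is rejected explicitly under torch.compile.
    if is_compiling() and isinstance(recipe, DelayedScaling):
        raise RuntimeError("DelayedScaling is not supported inside torch.compile")

try:
    check_recipe_support(DelayedScaling())
    raised = False
except RuntimeError:
    raised = True
print(raised)  # True
```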
Use str(recipe) for content-based recipe keying (avoids unbounded growth when identical recipes are constructed inline) and id(group) for process group identity (same semantics as the old hash(group) which was id-based). Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Made-with: Cursor
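A toy demonstration of why content-based keying avoids unbounded growth (`Recipe` here is a hypothetical stand-in, not TE's recipe class):

```python
class Recipe:
    def __init__(self, margin=0):
        self.margin = margin
    def __str__(self):
        return f"Recipe(margin={self.margin})"

group = object()          # stand-in for a process group
autocast_arguments = {}

# Constructing an identical recipe inline on every call would grow an
# id()-keyed dict without bound; str()-based keys collapse duplicates,
# while id(group) preserves process-group identity semantics.
for _ in range(100):
    recipe = Recipe(margin=0)
    key = (str(recipe), id(group))
    autocast_arguments[key] = (recipe, group)

print(len(autocast_arguments))  # 1
```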
/te-ci pytorch L1
Replace custom_op-based approach with torch.library.define/impl/register_fake using get_opaque_type_name() in the schema, which allows Inductor to properly handle opaque value types. Add ToyQuantizer as an opaque value-type wrapper around Float8CurrentScalingQuantizer with proper __eq__/__hash__/__fx_repr__. test_autocast_nested_custom validates that nested te.autocast with 3 distinct CustomRecipe instances passes the correct quantizers in both forward and backward. test_autocast_sanity is a smoke test for all hardware-supported built-in recipes. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Made-with: Cursor
```python
from torch._opaque_base import OpaqueBaseMeta
from torch._library.opaque_object import (
    get_opaque_type_name,
    register_opaque_type,
    MemberType,
)
```
Private PyTorch internal APIs imported without a stability guarantee
torch._opaque_base and torch._library.opaque_object are private modules (underscore-prefixed). PyTorch treats them as implementation details — their names, locations, and interfaces can change between minor releases without deprecation warnings. A future PyTorch upgrade could silently break every test in this file with an ImportError or AttributeError.
If there is an equivalent public API (e.g., something under torch.library without the underscore), it should be used instead. If no public API exists yet, the import should at minimum be wrapped in a version guard and a comment explaining why the private API is used:
```python
# torch._opaque_base / torch._library.opaque_object are private APIs as of PyTorch 2.x.
# These imports will need to be updated if/when a public surface is provided.
try:
    from torch._opaque_base import OpaqueBaseMeta
    from torch._library.opaque_object import get_opaque_type_name, register_opaque_type, MemberType
except ImportError:
    pytest.skip("Private torch opaque-type APIs not available in this PyTorch build")
```

```python
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("fp8_recipe", _all_recipes, ids=lambda r: type(r).__name__)
def test_autocast_sanity(fp8_recipe):
    """Smoke test: torch.nn.Linear inside a single te.autocast with each
    built-in recipe. Forward + backward under torch.compile(fullgraph=True)."""
    dtype = torch.bfloat16
    device = "cuda"

    model = torch.nn.Linear(32, 32, dtype=dtype, device=device)
    inp = torch.randn(8, 32, dtype=dtype, device=device, requires_grad=True)

    def fn(inp):
        with te.autocast(recipe=fp8_recipe):
            return model(inp)

    torch._dynamo.reset()
    compiled = torch.compile(fn, fullgraph=True)

    out = compiled(inp)
    out.sum().backward()
```
test_autocast_sanity uses torch.nn.Linear — no TE quantization is exercised
torch.nn.Linear is a standard PyTorch module that is completely unaware of TE's FP8 machinery. Wrapping it in te.autocast(recipe=fp8_recipe) does not cause any quantization to occur; the forward and backward passes run entirely in BF16 regardless of which recipe is passed. This means that despite the parametrisation over Float8CurrentScaling, Float8BlockScaling, MXFP8BlockScaling, and NVFP4BlockScaling, all four test variants exercise identical code paths and provide no evidence that any of those recipes actually work under torch.compile.
ToyLinear (defined in this same file specifically for this purpose) should be used instead. It calls into TE's BasicLinear functional ops and respects the active autocast recipe, so it would actually exercise the quantization code paths:
```python
model = ToyLinear(32, 32, dtype=dtype, device=device)
```

If ToyLinear cannot be driven with all four recipes today, those recipe variants should receive their own skip guards with a comment explaining the gap.
/te-ci pytorch
Description
Enable `torch.compile(fullgraph=True)` for FP8 autocast by moving compile-visible mutable state off class attributes, avoiding tracing through support checks, and adding tests.
Type of change
Changes
Please list the changes introduced in this PR:
- Move `cls` attribute writes to a dataclass-backed singleton object, because `torch.compile` does not support writes directly to class attributes.
- Replace `lru_cache`-based support checks with explicit module-level caches and mark the wrapper functions with `@torch.compiler.assume_constant_result` so `torch.compile` does not trace into `check_*_support()`.
- Add `torch.compile` coverage for FP8 autocast using a custom test module; the test is more involved because there is currently no simple TE layer that supports both FP8 and `torch.compile`.
- Make `DelayedScaling` explicitly unsupported under `torch.compile` and raise a clear error for that case.

Checklist: